home *** CD-ROM | disk | FTP | other *** search
/ MacHack 2001 / MacHack 2001.toast / pc / The Hacks / TiVo™ for QuicktimeTV™ / RTSP-record.py < prev   
Encoding:
Python Source  |  2001-06-23  |  17.3 KB  |  646 lines

  1. #!/usr/local/bin/python
  2.  
  3. # This records RTSP streams.
  4. # It outputs to a file the following:
  5. # - canonical output for initial request
  6. #
  7.  
  8. """
  9.  
  10. RTSP Proxy v1.2
  11. ---------------
  12. Jonathan Hogg <jonathan@onegoodidea.com>
  13.  
  14. Copyright (c) 1999 One Good Idea Limited <http://www.onegoodidea.com/>
  15.  
  16. Permission to use, copy, modify, and distribute this software and its
  17. documentation for any purpose, without fee, and without a written agreement
  18. is hereby granted, provided that the above copyright notice and this
  19. paragraph and the following two paragraphs appear in all copies.
  20.  
  21. IN NO EVENT SHALL ONE GOOD IDEA LIMITED BE LIABLE TO ANY PARTY FOR DIRECT, 
  22. INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST 
  23. PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, 
  24. EVEN IF ONE GOOD IDEA LIMITED HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 
  25. DAMAGE.
  26.  
  27. ONE GOOD IDEA LIMITED SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, 
  28. BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 
  29. FOR A PARTICULAR PURPOSE.  THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS 
  30. IS" BASIS, AND ONE GOOD IDEA LIMITED HAS NO OBLIGATIONS TO PROVIDE 
  31. MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
  32.  
  33.  
  34. Usage:
  35.  
  36.     % RTSP_Proxy
  37.  
  38.  
  39. The proxy listens on port 7070 so that it doesn't need to be run as root
  40. to operate (although this can be easily changed down the bottom of the
  41. script). It is a very simple program and can get confused, but in it's
  42. present state is about as functional as Apple's rtsp_proxy but a lot less
  43. buggy.
  44.  
  45. """
  46.  
  47.  
  48. import pickle
  49. import sys
  50. import string
  51. import re
  52. import threading
  53. import time
  54. from threading import *
  55.  
  56. from socket import *
  57. if not globals().has_key('IPPROTO_TCP'):
  58.     IPPROTO_TCP = 6
  59.  
  60. from select import *
  61.  
  62. import urlparse
  63. try:
  64.     if "rtsp" not in urlparse.uses_netloc:
  65.         urlparse.uses_netloc.append("rtsp")
  66. except:
  67.     pass
  68.  
  69.  
  70.  
  71. #------------------------------------------------------------------------
  72.  
  73. class Logger:
  74.  
  75.     def __init__( self, file = sys.stderr ):
  76.         self._lastmsg = ''
  77.         self._first = 1
  78.         self._repeats = 0
  79.         self._file = file
  80.         self._file.write( "[log started]" )
  81.         self._lock = Lock()
  82.     
  83.     def log( self, msg ):
  84.         self._lock.acquire()
  85.         if msg == self._lastmsg:
  86.             if self._repeats == 0:
  87.                 self._file.write( ' (.' )
  88.             self._file.write( '.' )
  89.             self._repeats = self._repeats + 1
  90.         else:
  91.             if self._repeats > 0:
  92.                 self._file.write( ')' )
  93.             self._file.write( '\n' )
  94.             self._first = 0
  95.             self._file.write( msg )
  96.             self._repeats = 0
  97.         self._file.flush()
  98.         if self._repeats == 75 - len(msg):
  99.             self._lastmsg = ''
  100.         else:
  101.             self._lastmsg = msg
  102.         self._lock.release()
  103.  
  104.  
  105. logger = Logger()
  106. debug = logger.log
  107.  
  108.  
  109. def makeportrange( ports ):
  110.  
  111.     if len(ports) == 1:
  112.         return "%d" % ports[0]
  113.     else:
  114.         return "%d-%d" % (ports[0], ports[-1])
  115.  
  116.  
  117.  
  118. #------------------------------------------------------------------------
  119.  
  120. class Message:
  121.  
  122.     def __init__( self, conn ):
  123.  
  124.         self._conn = conn
  125.         self._buffer = ""
  126.         self.readcommand()
  127.         self.readheaders()
  128.         self.readcontent()
  129.     debug("MESSAGE: " + self.getmessage())
  130.     
  131.     
  132.     def readdata( self ):
  133.  
  134.         self._buffer = self._buffer + self._conn.recv( 1024 )
  135.  
  136.  
  137.     def getdata( self, length):
  138.     
  139.         while 1:
  140.             if len(self._buffer) >= length:
  141.                 data = self._buffer[0:length]
  142.                 self._buffer = self._buffer[length:]
  143.                 return data
  144.             else:
  145.                 self.readdata()
  146.  
  147.  
  148.     def readline( self ):
  149.     
  150.         while 1:
  151.             if self._buffer == "":
  152.                 self.readdata()
  153.             
  154.             pos = string.find( self._buffer, "\r\n" )
  155.         
  156.             if pos <> -1:
  157.                 line = self._buffer[:pos]
  158.                 self._buffer = self._buffer[pos+2:]
  159.                 return line
  160.         
  161.             self.readdata()
  162.  
  163.  
  164.     def readcommand( self ):
  165.     
  166.         line = self.readline()
  167.         bits = string.split( line )
  168.         self._command = bits[0]
  169.         self._arguments = bits[1:]
  170.  
  171.  
  172.     def readheaders( self ):
  173.  
  174.         self._headerdict = {}
  175.         self._headerlist = []
  176.         
  177.         while 1:
  178.             line = self.readline()
  179.             if line == "":
  180.                 break
  181.             if line[0] in string.whitespace:
  182.                 header[1] = header[1] + string.lstrip(line)
  183.             else:
  184.                 (field,value) = string.split( line, ":", 1 )
  185.                 header = [field, string.strip(value)]
  186.                 self._headerlist.append( header )
  187.                 self._headerdict[string.lower(field)] = header
  188.     debug("HEADERS DICT: %s" % self._headerdict)
  189.  
  190.     def readcontent( self ):
  191.         
  192.         length = self.getheader('content-length')
  193.         if length:
  194.             self._content = self.getdata( int(length) )
  195.         else:
  196.             self._content = ""
  197.  
  198.  
  199.     def getmessage( self ):
  200.     
  201.         msg = self._command + " " + string.join( self._arguments ) + "\r\n"
  202.         
  203.         for header in self._headerlist:
  204.             msg = msg + "%s: %s\r\n" % (header[0], header[1])
  205.  
  206.         msg = msg + "\r\n" + self._content
  207.         
  208.         return msg
  209.         
  210.  
  211.     def getheader( self, field ):
  212.         
  213.         name = string.lower( field )
  214.         if self._headerdict.has_key( name ):
  215.             return self._headerdict[name][1]
  216.         else:
  217.             return None
  218.  
  219.  
  220.     def setheader( self, field, value ):
  221.         
  222.         self._headerdict[string.lower(field)][1] = value
  223.  
  224.  
  225.     def getcommand( self ):
  226.  
  227.         return self._command
  228.     
  229.     
  230.     def setcommand( self, command ):
  231.  
  232.         self._command = command
  233.     
  234.     
  235.     def getargs( self ):
  236.     
  237.         return self._arguments
  238.  
  239.  
  240.     def setargs( self, args ):
  241.  
  242.         self._arguments = args
  243.     
  244.     
  245. #------------------------------------------------------------------------
  246.  
  247. class SemaphoredWriteable:
  248.     def __init__(self, file):
  249.         self.file = file
  250.         self.semaphore = Semaphore()
  251.  
  252.     def write(self, data):
  253.         self.semaphore.acquire()
  254.         self.file.write(data)
  255.         self.file.flush()
  256.         self.semaphore.release()
  257.  
  258. #------------------------------------------------------------------------
  259.  
  260. class Session( Thread ):
  261.  
  262.     RTSP_PORT = 7070
  263.  
  264.  
  265.     def __init__( self, conn, addr ):
  266.     
  267.         Thread.__init__( self )
  268.         self._clientconn = conn
  269.         self._clientaddr = addr
  270.         self._serverconn = None
  271.         self._serveraddr = ''
  272.     filename = time.strftime("%Y-%m-%d-%H:%M", time.localtime(time.time())) + ".save"
  273.     self._output = SemaphoredWriteable(open(filename, "w"))
  274.     self._startTime = time.time()
  275.     self.setDaemon( 1 )
  276.  
  277.     def archiveMessageReceivedFromClient(self, message):
  278.     ob = (time.time(), "messageFromClient", message)
  279.     p = pickle.dumps(ob)
  280.     self._output.write(p)
  281.  
  282.     def archiveMessageSentToClient(self, message):
  283.     ob = (time.time(), "messageToClient", message)
  284.     p = pickle.dumps(ob)
  285.     self._output.write(p)
  286.  
  287.     def sendclientmsg( self, msg ):
  288.     m = msg.getmessage()
  289.         self._clientconn.send( m )
  290.     self.archiveMessageSentToClient(m)
  291.  
  292.     
  293.     def getservermsg( self ):
  294.         m = Message( self._serverconn )
  295.     return m
  296.  
  297.  
  298.     def sendservermsg( self, msg ):
  299.         self._serverconn.send( msg.getmessage() )
  300.  
  301.  
  302.     def dispatch( self, msg ):
  303.  
  304.         self.archiveMessageReceivedFromClient(msg.getmessage())
  305.  
  306.         command = msg.getcommand()
  307.         
  308.         debug( "got command: " + command )
  309.         
  310.         if command == "DESCRIBE":
  311.             self.do_describe( msg )
  312.             
  313.         elif command == "SETUP":
  314.             self.do_setup( msg )
  315.             
  316.         elif command == "OPTIONS":
  317.             self.do_options( msg )
  318.  
  319.         else:
  320.             self.sendservermsg( msg )
  321.             response = self.getservermsg()
  322.             self.sendclientmsg( response )
  323.     
  324.     
  325.     def do_options( self, msg ):
  326.     
  327.         if self._client_type[:4] == 'QTS/' and self._server_type == 'QTSS/v66':
  328.             debug( '  translating OPTIONS into a GET_PARAMETER ping for broken QuickTime' )
  329.             msg.setcommand( 'GET_PARAMETER' )
  330.             self.sendservermsg( msg )
  331.             
  332.             msg.setcommand( 'RTSP/1.0' )
  333.             msg.setargs( ['200', 'OK'] )
  334.             self.sendclientmsg( msg )
  335.         
  336.         else:
  337.             self.sendservermsg( msg )
  338.             response = self.getservermsg()
  339.             self.sendclientmsg( response )
  340.  
  341.  
  342.     def do_describe( self, msg ):
  343.     
  344.         url = msg.getargs()[0]
  345.         parsed_url = urlparse.urlparse( url, "rtsp" )
  346.         site = parsed_url[1]
  347.         self._client_type = msg.getheader('user-agent')
  348.         debug( "  client is a: %s" % self._client_type )
  349.  
  350.         pos = string.find( site, ":" )
  351.         if pos >= 0:
  352.             addr,port = string.split( site, ":" )
  353.             port = int( port )
  354.         else:
  355.             addr = site
  356.             port = self.RTSP_PORT
  357.         
  358.         if not self._serverconn:
  359.             debug( "  trying connection to %s:%d" % (addr,port) )
  360.             sock = socket( AF_INET, SOCK_STREAM )
  361.             sock.connect( (addr,port) )
  362.             self._serverconn = sock
  363.             self._serveraddr = addr
  364.         
  365.         self.sendservermsg( msg )
  366.         response = self.getservermsg()
  367.         self._server_type = response.getheader('server')
  368.         debug( "  server is a: %s" % self._server_type )
  369.         self.sendclientmsg( response )
  370.  
  371.  
  372.     def do_setup( self, msg ):
  373.     
  374.         client_port = ''
  375.         debug( "  client requests of proxy:\n    %s" % msg.getheader('transport') )
  376.         for bit in string.split( msg.getheader('transport'), ";" ):
  377.             bit = string.strip( bit )
  378.             
  379.             if string.find( bit, '=' ) > 0:
  380.                 name, value = string.split( bit, '=', 1 )
  381.             
  382.                 if name == 'client_port':
  383.                     client_port = value
  384.  
  385.         if string.find( client_port, "-" ):
  386.             startport,endport = string.split( client_port, "-" )
  387.             clientports = range( int(startport), int(endport) + 1 )
  388.         else:
  389.             clientports = [ int(client_port) ]
  390.  
  391.     # find the trackID
  392.     URI = msg.getargs()[0]
  393.     print "ARGS->", msg.getargs()
  394.     trackID = 0
  395.     m = re.search("trackID=(\d+)", URI)
  396.     if m != None: trackID = int(m.group(1))
  397.     print "track ID = ", trackID
  398.  
  399.         proxy = Forwarder( self._clientaddr, clientports, trackID, self._output )
  400.  
  401.         msg.setheader( 'transport', 'RTP/AVP;unicast;client_port=' + proxy.getportrange() )
  402.  
  403.         debug( "  proxy requests of server:\n    " + str(msg.getheader('transport') ))
  404.         
  405.         self.sendservermsg( msg )
  406.         response = self.getservermsg()
  407.         
  408.         server_port = ''
  409.         source = ''
  410.         
  411.         debug( "  server offers to proxy:\n    " + str(response.getheader('transport') ))
  412.         
  413.         for bit in string.split( response.getheader('transport'), ";" ):
  414.             bit = string.strip( bit )
  415.             
  416.             if string.find( bit, '=' ) > 0:
  417.                 name, value = string.split( bit, '=', 1 )
  418.             
  419.                 if name == 'server_port':
  420.                     server_port = value
  421.                 
  422.                 elif name == 'source':
  423.                     source = value
  424.         
  425.         if string.find( server_port, "-" ):
  426.             startport,endport = string.split( server_port, "-" )
  427.             serverports = range( int(startport), int(endport) + 1 )
  428.         else:
  429.             serverports = [ int(server_port) ]
  430.         
  431.         
  432.         if source <> '':
  433.             addr = source
  434.         else:
  435.             addr = self._serveraddr
  436.         
  437.         proxy.setserver( addr, serverports )
  438.         proxy.start()
  439.         
  440.         response.setheader( 'transport',
  441.                             'RTP/AVP;unicast;client_port=%s;server_port=%s' % (client_port,
  442.                                 proxy.getportrange()) )
  443.  
  444.         debug( "  proxy offers to client:\n    " + response.getheader('transport') )
  445.         
  446.         self.sendclientmsg( response )
  447.  
  448.  
  449.     def run( self ):
  450.     
  451. #        try:
  452.             while 1:
  453.                 msg = Message( self._clientconn )
  454.                 self.dispatch( msg )
  455.         
  456. #        except:
  457.             debug( "taking down session" )
  458.             self._clientconn.close()
  459.             if self._serverconn:
  460.                 self._serverconn.close()
  461.  
  462.  
  463.  
  464. #------------------------------------------------------------------------
  465.  
  466. class Listener:
  467.  
  468.  
  469.     def __init__( self, port ):
  470.     
  471.         self._sock = socket( AF_INET, SOCK_STREAM )
  472.         self._sock.bind( ('',port) )
  473.         self._sock.setsockopt( IPPROTO_TCP, SO_REUSEADDR, 1 )
  474.         self._sock.listen( 5 )
  475.  
  476.  
  477.     def waitforclient( self ):
  478.     
  479.         conn, addr = self._sock.accept()
  480.         debug( "accepted connection from %s:%d" % addr )
  481.         return Session( conn, addr[0] )
  482.  
  483.  
  484.     def stop( self ):
  485.         self._sock.close()
  486.  
  487.  
  488.  
  489. #------------------------------------------------------------------------
  490.  
  491. class Forwarder( Thread ):
  492.  
  493.     START_PORT = 10000
  494.     _currentport = START_PORT
  495.     
  496.  
  497.     def __init__( self, addr, ports, trackID, archive ):
  498.     
  499.         Thread.__init__( self )
  500.         self._clientaddr = gethostbyname( addr )
  501.         self._clientports = ports
  502.     self._archive = archive
  503.     self._trackID = trackID
  504.         self.buildports()
  505.     self.setDaemon( 1 )
  506.     
  507.     
  508.     def _allocateports( self, howmany ):
  509.     
  510.         start = Forwarder._currentport
  511.         sofar = 0
  512.         socks = []
  513.         
  514.         while sofar < howmany:
  515.         
  516.             sock = socket( AF_INET, SOCK_DGRAM )
  517.             port = Forwarder._currentport
  518.             Forwarder._currentport = Forwarder._currentport + 1
  519.             
  520.             try:
  521.                 sock.bind( ('',port) )
  522.             except:
  523.                 sofar = 0
  524.                 start = self._currentport
  525.                 socks = []
  526.                 
  527.             socks.append( (port,sock) )
  528.             sofar = sofar + 1
  529.             end = port
  530.         
  531.         debug( "  allocated a port range at %d-%d" % (start,end) )
  532.  
  533.         return socks
  534.  
  535.  
  536.     def setserver( self, addr, ports ):
  537.     
  538.         self._serveraddr = gethostbyname( addr )
  539.         self._serverports = ports
  540.  
  541.     
  542.     def getportrange( self ):
  543.     
  544.         ports = map( lambda x: x[0], self._proxysocks )
  545.         return makeportrange( ports )
  546.  
  547.     def archiveData(self, trackID, index, data):
  548.     ob = (time.time(), "data", trackID, index, data)
  549.     p = pickle.dumps(ob, 1) # dump as binary
  550.     self._archive.write(p)
  551.  
  552.     def doforwarding( self ):
  553.     
  554.         sockets = map( lambda x: x[1], self._proxysocks )
  555.         count = 0
  556.         size = 0
  557.         then = time.time()
  558.     
  559.         while 1:
  560.             readylist, _, _ = select( sockets, [], [] )
  561.             
  562.             for i in range(len(self._proxysocks)):
  563.             
  564.                 port, sock = self._proxysocks[i]
  565.                 clientport = self._clientports[i]
  566.                 serverport = self._serverports[i]
  567.                 
  568.                 if sock in readylist:
  569.                     (packet, addr) = sock.recvfrom( 2048 )
  570.                     
  571.                     if addr == (self._serveraddr, serverport):
  572.                     self.archiveData(self._trackID, i, packet)
  573.                         count = count + 1
  574.                         size = size + len(packet)
  575.                         sock.sendto( packet, (self._clientaddr, clientport) )
  576.              debug(("sending  %d bytes to port " % len(packet)) + str(clientport))
  577.                     
  578.                         now = time.time()
  579.                         elapsed = now - then
  580.                         if elapsed > 30.0:
  581.                             rate = (size * 8) / elapsed
  582.                             debug( "forwarding rate approx %dbps on ports %s" % (rate,
  583.                                         self.getportrange()) )
  584.                             count = 0
  585.                             size = 0
  586.                             then = now
  587.                             
  588.                     elif addr == (self._clientaddr, clientport):
  589.                         debug( "forwarding from client" )
  590.                         sock.sendto( packet, (self._serveraddr, serverport) )
  591.                     
  592.                     elif addr[1] == serverport:
  593.                         debug( "server %s lied about its' address, it's really %s" % (
  594.                             self._serveraddr, addr[0]) )
  595.                         self._serveraddr = addr[0]
  596.                         sock.sendto( packet, (self._clientaddr, clientport) )
  597.                     
  598.                     else:
  599.                         debug( "forwarder received packet from unexpected source, %s:%d" % addr )
  600.             
  601.  
  602.  
  603.     def run( self ):
  604.     
  605.         debug( "  starting an RTP forwarder on ports " + self.getportrange() )
  606.         debug( "    server %s:%s  <->  client %s:%s" % (
  607.             self._serveraddr, makeportrange(self._serverports),
  608.             self._clientaddr, makeportrange(self._clientports)))
  609.  
  610. #    try:
  611.     self.doforwarding()
  612.  
  613. #    except:
  614.     debug( "stopping forwarder on ports " + self.getportrange() )
  615.         
  616.     
  617.     def buildports( self ):
  618.     
  619.         num = len( self._clientports )
  620.         self._proxysocks = self._allocateports( num )
  621.  
  622.  
  623.  
  624. #------------------------------------------------------------------------
  625.  
  626. def main( argv ):
  627.  
  628.     listener = Listener( 7272 )
  629.     
  630.     debug( "waiting for a client" )
  631.     
  632.     try:
  633.         while 1:
  634.             client = listener.waitforclient()
  635.             client.start()
  636.  
  637.     finally:
  638.     debug("stopping listener")
  639.         listener.stop()
  640.  
  641.  
  642. if __name__ == "__main__":
  643.     main( sys.argv )
  644.  
  645.  
  646.